library(tidyverse)
library(ComplexHeatmap)
library(circlize)
library(pheatmap)
library(corrplot)
library(ggplot2)
library(reshape2)
library(grid)
library(dplyr)


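# Load the input tables before running this script: it expects
# `df_filtered_list` (the list of tables to plot) and `Legend_name` (one
# legend/file label per table) to already exist in the workspace.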

#------------------- start -----------------------

# Display order of the clusters used as heatmap columns; this is also the
# order in which sort_matrix_columnwise() visits columns when ordering rows.
col_order <- c(
  "Bf_21", "Bf_31", "Bf_33", "Bf_16", "Bf_30", "Bf_35",
  "Bf_6",
  "Bf_17",
  "Bf_28", "Bf_15",
  "Bf_18",
  "Bf_34",
  "Bf_13", "Bf_3", "Bf_8",
  "Bf_39", "Bf_36", "Bf_9", "Bf_22", "Bf_7", "Bf_1",
  "Bf_26", "Bf_24", "Bf_25",
  "Bf_5",  "Bf_37", "Bf_20", "Bf_27", "Bf_10",
  "Bf_14", "Bf_32",
  "Bf_29",
  "Bf_11", "Bf_38", "Bf_0", "Bf_23", "Bf_2", "Bf_19", "Bf_4", "Bf_12"
)
# Each cluster must appear exactly once.
stopifnot(!anyDuplicated(col_order))

# Cell-type label for each entry of `col_order`, position for position.
# Derived from the set of non-neuronal clusters (instead of a hand-typed
# parallel vector) so the labels stay aligned with `col_order` whenever
# clusters are reordered, added, or removed.
column_groups <- ifelse(
  col_order %in% c("Bf_21", "Bf_16", "Bf_30", "Bf_34",
                   "Bf_13", "Bf_36", "Bf_11", "Bf_0"),
  "Non-neuron",
  "Neuron"
)

# Column annotation bar labelling each cluster as "Neuron" or "Non-neuron"
# (orchid / light green), with small annotation-name text and a thin bar.
# NOTE(review): the labels in `column_groups` are applied to heatmap columns
# by position, so this presumably assumes the plotted matrix's columns arrive
# in `col_order` order -- confirm before re-enabling `top_annotation`.
top_anno <- HeatmapAnnotation(
  type = column_groups,
  col = list(
    type = c(Neuron = "orchid", `Non-neuron` = "lightgreen")
  ),
  simple_anno_size = unit(2.5, "mm"),
  annotation_name_gp = gpar(fontsize = 8)
)


# Cells whose value exceeds this threshold both drive the row ordering and
# are marked with a dot in the heatmap.
highlight_threshold <- 0.1

# Reorder the rows of `df` column by column: visit the columns in
# `col_order`; for each column, take every not-yet-placed row whose value in
# that column exceeds `threshold`, in decreasing order of that value. Rows
# that never exceed the threshold are appended last, in the (decreasing)
# order of the last column visited. NA never counts as exceeding the
# threshold and sorts last within a column.
#
# @param df         data.frame of numeric columns.
# @param threshold  value a cell must exceed to place its row.
# @param col_order  column names to visit, in order; NULL visits every column
#                   in its existing order. Names absent from `df` are skipped.
# @return `df` with its rows reordered; columns and row names unchanged.
sort_matrix_columnwise <- function(df, threshold = 0.1, col_order = NULL) {
  if (is.null(col_order)) {
    col_order <- colnames(df)
  }
  # A missing column would look up as NULL and empty the pool of unplaced
  # rows, silently dropping them from the result -- skip such names instead.
  col_order <- col_order[col_order %in% colnames(df)]

  remaining <- seq_len(nrow(df))               # unplaced rows, current order
  placed <- vector("list", length(col_order))  # rows fixed by each column

  for (pos in seq_along(col_order)) {
    if (length(remaining) == 0L) break
    values <- df[[col_order[pos]]][remaining]
    ord <- order(values, decreasing = TRUE)
    remaining <- remaining[ord]
    values <- values[ord]
    hit <- !is.na(values) & values > threshold
    placed[[pos]] <- remaining[hit]
    remaining <- remaining[!hit]
  }

  df[c(unlist(placed), remaining), , drop = FALSE]
}

# Build a ComplexHeatmap `cell_fun` that draws a yellow dot on every cell of
# `mat` whose value exceeds `threshold`. `mat` is bound when the callback is
# created, so drawing does not depend on whatever a global variable holds.
make_dot_cell_fun <- function(mat, threshold) {
  force(mat)
  force(threshold)
  function(j, i, x, y, width, height, fill) {
    value <- mat[i, j]
    if (!is.na(value) && value > threshold) {
      grid.points(x, y, pch = 16, size = unit(1.5, "mm"), gp = gpar(col = "#ff3"))
    }
  }
}

# Draw `ht` on a throw-away null PDF device (no file is written) and return
# its drawn size in inches as c(width = , height = ). `...` is forwarded to
# draw(). The device is closed even if drawing fails.
measure_heatmap_inches <- function(ht, ...) {
  pdf(NULL)
  dev_id <- dev.cur()
  on.exit(dev.off(dev_id), add = TRUE)
  ht_drawn <- draw(ht, ...)
  c(
    width = convertWidth(ComplexHeatmap:::width(ht_drawn), "inch", valueOnly = TRUE),
    height = convertHeight(ComplexHeatmap:::height(ht_drawn), "inch", valueOnly = TRUE)
  )
}

# Write `ht` to a PDF at `path` with the given page size (inches); `...` is
# forwarded to draw(). The device is closed even if drawing fails.
export_heatmap_pdf <- function(ht, path, width, height, ...) {
  pdf(path, width = width, height = height)
  dev_id <- dev.cur()
  on.exit(dev.off(dev_id), add = TRUE)
  draw(ht, ...)
  invisible(path)
}

# One legend/file label is needed per table.
stopifnot(length(Legend_name) >= length(df_filtered_list))

for (idx in seq_along(df_filtered_list)) {
  legend_label <- Legend_name[[idx]]

  df <- sort_matrix_columnwise(
    df_filtered_list[[idx]],
    threshold = highlight_threshold,
    col_order = col_order
  )
  mat <- as.matrix(df)

  # Columns follow `col_order`; any column not listed there goes at the end.
  present_cols <- intersect(col_order, colnames(mat))
  heatmap_col_order <- c(present_cols, setdiff(colnames(mat), present_cols))

  # generate heatmap (rows are already in display order after sorting)
  ht <- Heatmap(
    mat,
    # top_annotation = c(top_anno, top_anno3),  # NOTE(review): top_anno3 is not defined above
    col = colorRamp2(seq(0, 1, length.out = 25), hcl.colors(25, "Oslo")),
    name = legend_label,
    column_names_rot = -90,
    column_order = heatmap_col_order,
    row_names_side = "right",
    row_order = seq_len(nrow(mat)),
    show_row_dend = TRUE,
    width = ncol(mat) * unit(5, "mm"),
    height = nrow(mat) * unit(3.5, "mm"),
    rect_gp = gpar(col = "#aaa", lwd = 1),
    row_names_gp = gpar(fontsize = 10),
    column_names_gp = gpar(fontsize = 10),
    show_heatmap_legend = TRUE,
    cell_fun = make_dot_cell_fun(mat, highlight_threshold),
    border = TRUE
  )

  # export: size the PDF page to the drawn heatmap plus a 10% margin
  size_inch <- measure_heatmap_inches(
    ht,
    show_heatmap_legend = FALSE,
    show_annotation_legend = FALSE
  )

  path <- paste0("./Mouse_SubClass vs Bf snRNAseq_r1_", legend_label, ".pdf")

  export_heatmap_pdf(
    ht,
    path,
    width = size_inch[["width"]] * 1.1,
    height = size_inch[["height"]] * 1.1,
    show_heatmap_legend = FALSE,
    show_annotation_legend = FALSE
  )
}